{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4.1 模型构造"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.2.0\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.1.1 继承`Module`类来构造模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"Multilayer perceptron: 784 inputs -> 256 hidden units (ReLU) -> 10 outputs.\"\"\"\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        # Run nn.Module's constructor first, forwarding any extra keyword\n",
    "        # arguments supplied when the instance is created.\n",
    "        super().__init__(**kwargs)\n",
    "        # Layers that own parameters: two fully connected layers around a ReLU.\n",
    "        self.hidden = nn.Linear(784, 256)  # hidden layer\n",
    "        self.act = nn.ReLU()\n",
    "        self.output = nn.Linear(256, 10)  # output layer\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Compute the model output for input ``x``: hidden -> ReLU -> output.\"\"\"\n",
    "        hidden_out = self.hidden(x)\n",
    "        activated = self.act(hidden_out)\n",
    "        return self.output(activated)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MLP(\n",
      "  (hidden): Linear(in_features=784, out_features=256, bias=True)\n",
      "  (act): ReLU()\n",
      "  (output): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0234, -0.2646, -0.1168, -0.2127,  0.0884, -0.0456,  0.0811,  0.0297,\n",
       "          0.2032,  0.1364],\n",
       "        [ 0.1479, -0.1545, -0.0265, -0.2119, -0.0543, -0.0086,  0.0902, -0.1017,\n",
       "          0.1504,  0.1144]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Random batch of 2 samples with 784 features each, matching MLP's first layer.\n",
    "X = torch.rand(2, 784)\n",
    "net = MLP()\n",
    "print(net)\n",
    "# Calling the instance runs the forward computation; the bare expression is\n",
    "# displayed as the cell output.\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.1.2 `Module`的子类\n",
    "### 4.1.2.1 `Sequential`类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "\n",
    "\n",
    "class MySequential(nn.Module):\n",
    "    \"\"\"Minimal re-implementation of ``nn.Sequential``.\n",
    "\n",
    "    Accepts either a single ``OrderedDict`` mapping names to modules, or any\n",
    "    number of modules as positional arguments (named 0, 1, ... in order).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *args):\n",
    "        super(MySequential, self).__init__()\n",
    "        # OrderedDict is imported at cell level: a name bound in the class body\n",
    "        # is not visible from inside methods, so a class-level import raises\n",
    "        # NameError here as soon as exactly one argument is passed.\n",
    "        if len(args) == 1 and isinstance(args[0], OrderedDict):\n",
    "            # A single OrderedDict: register each module under its own key.\n",
    "            for key, module in args[0].items():\n",
    "                # add_module stores the module in self._modules.\n",
    "                self.add_module(key, module)\n",
    "        else:\n",
    "            # Plain positional modules: use the position as the name.\n",
    "            for idx, module in enumerate(args):\n",
    "                self.add_module(str(idx), module)\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Feed ``input`` through every registered module in insertion order.\"\"\"\n",
    "        # self._modules is an ordered mapping, so iteration follows the order\n",
    "        # in which the members were added.\n",
    "        for module in self._modules.values():\n",
    "            input = module(input)\n",
    "        return input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MySequential(\n",
      "  (0): Linear(in_features=784, out_features=256, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.1273,  0.1642, -0.1060,  0.1401,  0.0609, -0.0199, -0.0140, -0.0588,\n",
       "          0.1765, -0.1296],\n",
       "        [ 0.0267,  0.1670, -0.0626,  0.0744,  0.0574,  0.0413,  0.1313, -0.1479,\n",
       "          0.0932, -0.0615]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Positional modules are registered under the names 0, 1 and 2.\n",
    "net = MySequential(\n",
    "        nn.Linear(784, 256),\n",
    "        nn.ReLU(),\n",
    "        nn.Linear(256, 10), \n",
    "        )\n",
    "print(net)\n",
    "# X is the (2, 784) batch created in an earlier cell.\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1.2.2 `ModuleList`类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear(in_features=256, out_features=10, bias=True)\n",
      "ModuleList(\n",
      "  (0): Linear(in_features=784, out_features=256, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])\n",
    "net.append(nn.Linear(256, 10)) # append works like list.append\n",
    "print(net[-1])  # index access works like a list\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# net(torch.zeros(1, 784))  # raises NotImplementedError: ModuleList is only a container and defines no forward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class MyModule(nn.Module):\n",
    "    \"\"\"Ten 10x10 linear layers held in a ModuleList and combined pairwise.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        layers = [nn.Linear(10, 10) for _ in range(10)]\n",
    "        self.linears = nn.ModuleList(layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"At step ``idx``, add the output of layer ``idx // 2`` to that of layer ``idx``.\"\"\"\n",
    "        # A ModuleList supports both iteration and integer indexing.\n",
    "        for idx, layer in enumerate(self.linears):\n",
    "            paired = self.linears[idx // 2]\n",
    "            x = paired(x) + layer(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "net1:\n",
      "torch.Size([10, 10])\n",
      "torch.Size([10])\n",
      "net2:\n"
     ]
    }
   ],
   "source": [
    "class Module_ModuleList(nn.Module):\n",
    "    \"\"\"Holds its single linear layer inside an ``nn.ModuleList``.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        layers = [nn.Linear(10, 10)]\n",
    "        self.linears = nn.ModuleList(layers)\n",
    "    \n",
    "class Module_List(nn.Module):\n",
    "    \"\"\"Holds its single linear layer inside a plain Python list.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        layer = nn.Linear(10, 10)\n",
    "        self.linears = [layer]\n",
    "\n",
    "net1 = Module_ModuleList()\n",
    "net2 = Module_List()\n",
    "\n",
    "# The layer held in the ModuleList shows up in parameters() (weight and bias).\n",
    "print(\"net1:\")\n",
    "for p in net1.parameters():\n",
    "    print(p.size())\n",
    "\n",
    "# The recorded output shows nothing after this label: the layer kept in a\n",
    "# plain list is not reported by parameters().\n",
    "print(\"net2:\")\n",
    "for p in net2.parameters():\n",
    "    print(p)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1.2.3 `ModuleDict`类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear(in_features=784, out_features=256, bias=True)\n",
      "Linear(in_features=256, out_features=10, bias=True)\n",
      "ModuleDict(\n",
      "  (act): ReLU()\n",
      "  (linear): Linear(in_features=784, out_features=256, bias=True)\n",
      "  (output): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "net = nn.ModuleDict({\n",
    "    'linear': nn.Linear(784, 256),\n",
    "    'act': nn.ReLU(),\n",
    "})\n",
    "net['output'] = nn.Linear(256, 10) # add an entry, like dict item assignment\n",
    "print(net['linear']) # look up an entry by key\n",
    "print(net.output)  # the same entries are also reachable as attributes\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# net(torch.zeros(1, 784))  # raises NotImplementedError: ModuleDict is only a container and defines no forward()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.1.3 构造复杂的模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class FancyMLP(nn.Module):\n",
    "    \"\"\"MLP mixing a fixed random weight, a reused linear layer and Python control flow.\n",
    "\n",
    "    ``forward`` returns a scalar tensor (the sum of the final activations).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        super(FancyMLP, self).__init__(**kwargs)\n",
    "\n",
    "        # Constant (non-trainable) weight. Registering it as a buffer keeps it\n",
    "        # out of parameters() but lets .to()/.double()/.cuda() convert it along\n",
    "        # with the layers and stores it in state_dict(); a plain tensor\n",
    "        # attribute is skipped by all of those.\n",
    "        self.register_buffer('rand_weight', torch.rand((20, 20)))\n",
    "        self.linear = nn.Linear(20, 20)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the summed output for a 2-D batch ``x`` with 20 features per row.\"\"\"\n",
    "        x = self.linear(x)\n",
    "        # Use the constant weight with nn.functional.relu and torch.mm. The\n",
    "        # buffer does not require grad, so no .data access is needed.\n",
    "        x = nn.functional.relu(torch.mm(x, self.rand_weight) + 1)\n",
    "\n",
    "        # Reuse the same fully connected layer: equivalent to two layers that\n",
    "        # share their parameters.\n",
    "        x = self.linear(x)\n",
    "        # Control flow on tensor values: item() converts the norm to a Python\n",
    "        # scalar for the comparison.\n",
    "        while x.norm().item() > 1:\n",
    "            x /= 2\n",
    "        if x.norm().item() < 0.8:\n",
    "            x *= 10\n",
    "        return x.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FancyMLP(\n",
      "  (linear): Linear(in_features=20, out_features=20, bias=True)\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(0.8907, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Batch of 2 samples with 20 features, matching FancyMLP's Linear(20, 20).\n",
    "X = torch.rand(2, 20)\n",
    "net = FancyMLP()\n",
    "print(net)\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): NestMLP(\n",
      "    (net): Sequential(\n",
      "      (0): Linear(in_features=40, out_features=30, bias=True)\n",
      "      (1): ReLU()\n",
      "    )\n",
      "  )\n",
      "  (1): Linear(in_features=30, out_features=20, bias=True)\n",
      "  (2): FancyMLP(\n",
      "    (linear): Linear(in_features=20, out_features=20, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(-0.4605, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class NestMLP(nn.Module):\n",
    "    \"\"\"Wraps an ``nn.Sequential`` of Linear(40, 30) followed by ReLU.\"\"\"\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        layers = [nn.Linear(40, 30), nn.ReLU()]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Pass ``x`` through the wrapped sequential block.\"\"\"\n",
    "        out = self.net(x)\n",
    "        return out\n",
    "\n",
    "# Chain the pieces: NestMLP (40 -> 30), a linear layer (30 -> 20), then FancyMLP.\n",
    "net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())\n",
    "\n",
    "# Batch of 2 samples with 40 features, matching NestMLP's Linear(40, 30).\n",
    "X = torch.rand(2, 40)\n",
    "print(net)\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
